{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0d6f6f6a-e5de-4f67-9829-7b49c91361c4",
   "metadata": {},
   "source": [
    "## Tools 案例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8da5d947-9589-406e-966f-0b6e1ef08d9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CompletionMessage(content=None, role='assistant', tool_calls=[CompletionMessageToolCall(id='call_9051234730097475169', function=Function(arguments='{\"date\": \"2024-01-01\", \"departure\": \"北京南站\", \"destination\": \"上海\"}', name='Ticket'), type='function', index=0)])\n"
     ]
    }
   ],
   "source": [
    "from pydantic import BaseModel, Field\n",
    "from typing import List\n",
    "from typing_extensions import Literal\n",
    "\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"Ticket\",\n",
    "            \"description\": \"根据用户提供的信息查询火车时刻\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"date\": {\n",
    "                        \"description\": \"要查询的火车日期\",\n",
    "                        \"title\": \"Date\",\n",
    "                        \"type\": \"string\",\n",
    "                    },\n",
    "                    \"departure\": {\n",
    "                        \"description\": \"出发城市或车站\",\n",
    "                        \"title\": \"Departure\",\n",
    "                        \"type\": \"string\",\n",
    "                    },\n",
    "                    \"destination\": {\n",
    "                        \"description\": \"要查询的火车日期\",\n",
    "                        \"title\": \"Destination\",\n",
    "                        \"type\": \"string\",\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"date\", \"departure\", \"destination\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]\n",
    "\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"你能帮我查一下2024年1月1日从北京南站到上海的火车票吗？\"\n",
    "    }\n",
    "]\n",
    "\n",
    "from zhipuai import ZhipuAI\n",
    "client = ZhipuAI(api_key=\"填入你的key\")\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "    model=\"glm-4-plus\", # 请填写您要调用的模型名称\n",
    "    messages=messages,\n",
    "    tools=tools,\n",
    "    tool_choice=\"auto\",\n",
    ")\n",
    "print(response.choices[0].message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9248c980-45bf-4652-9ac8-d5b807788b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": Ticket.schema()['title'],\n",
    "            \"description\": Ticket.schema()['description'],\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": Ticket.schema()['properties'],\n",
    "                \"required\": Ticket.schema()['required'],\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83c55cec-6efe-4a03-a564-358a975cebdf",
   "metadata": {},
   "source": [
    "## Tools 自动解析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a9304278-fa93-4f8e-97c1-b0d8c5058029",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GLMAgent:\n",
    "    def __init__(self, model_name: str):\n",
    "        self.model_name = model_name\n",
    "\n",
    "    def call(self, user_prompt, response_model):\n",
    "        messages = [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": user_prompt\n",
    "            }\n",
    "        ]\n",
    "        tools = [\n",
    "            {\n",
    "                \"type\": \"function\",\n",
    "                \"function\": {\n",
    "                    \"name\": response_model.schema()['title'],\n",
    "                    \"description\": response_model.schema()['description'],\n",
    "                    \"parameters\": {\n",
    "                        \"type\": \"object\",\n",
    "                        \"properties\": response_model.schema()['properties'],\n",
    "                        \"required\": response_model.schema()['required'],\n",
    "                    },\n",
    "                }\n",
    "            }\n",
    "        ]\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model=self.model_name,\n",
    "            messages=messages,\n",
    "            tools=tools,\n",
    "            tool_choice=\"auto\",\n",
    "        )\n",
    "        try:\n",
    "            arguments = response.choices[0].message.tool_calls[0].function.arguments\n",
    "            return response_model.model_validate_json(arguments)\n",
    "        except:\n",
    "            print('ERROR', response.choices[0].message)\n",
    "            return None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "77210fd6-ff6b-4b88-81c9-02ffe2832621",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Ticket(date='2024-01-01', departure='北京南站', destination='上海')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Ticket(BaseModel):\n",
    "    \"\"\"根据用户提供的信息查询火车时刻\"\"\"\n",
    "    date: str = Field(description=\"要查询的火车日期\")\n",
    "    departure: str = Field(description=\"出发城市或车站\")\n",
    "    destination: str = Field(description=\"要查询的火车日期\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call(\"你能帮我查一下2024年1月1日从北京南站到上海的火车票吗？\", Ticket)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4c27240-cdd3-43a8-abb1-4679b79aad08",
   "metadata": {},
   "source": [
    "## 中文分词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "e3c4b149-6980-4f7f-a57e-6a8c28961e1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(keyword=['阿水', '强哥', '好朋友', '谢大脚', '长贵', '老公'])"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"抽取句子中的的单词，进行文本分词\"\"\"\n",
    "    keyword: List[str] = Field(description=\"单词\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call('阿水是强哥的好朋友。谢大脚是长贵的老公。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90cb92fd-3835-4c87-a6c3-59f2df21eebe",
   "metadata": {},
   "source": [
    "## 文本分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "752c9c26-9235-4a18-9b24-d4a5e6758a37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(sentiment='正向')"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"分析文本的情感\"\"\"\n",
    "    sentiment: Literal[\"正向\", \"反向\"] = Field(description=\"情感类型\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-air').call('我今天很开心。', Text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "cc9f7f5c-47dd-4ed1-9c7c-fbf7d4d3e809",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(sentiment='postivate')"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"分析文本的情感\"\"\"\n",
    "    sentiment: Literal[\"postivate\", \"negative\"] = Field(description=\"情感类型\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-air').call('我今天很开心。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1a53df2-8002-4f2c-bb86-e357cf1e378d",
   "metadata": {},
   "source": [
    "## 文本匹配"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "de1461de-bee8-4316-ac00-88668d29e3e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(score=0.0)"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"判断句子是否语义相近\"\"\"\n",
    "    score:float = Field(description=\"文本相似度，介于0与1。0代表不相似，1代表相似\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-air').call('我今天很开心 与 我今天不开心。', Text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "5c106c28-a85b-4d05-a065-403412e3f228",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(score=0.8)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GLMAgent(model_name = 'glm-4-air').call('我今天很开心 与 我今天超级开心。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67c5fb94-edbb-402b-a4c6-53d8cf92b3f9",
   "metadata": {},
   "source": [
    "## 实体抽取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "7c110007-78f0-4e43-87f9-4a493023a21f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(person=['我', '徐也也', '强哥'], location=['海淀'])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"抽取实体\"\"\"\n",
    "    person: List[str] = Field(description=\"人名\")\n",
    "    location: List[str] = Field(description=\"地名\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-air').call('今天我和徐也也去海淀吃饭，强哥也去了。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "780993f3-e07b-47de-a4ba-7c282f247b23",
   "metadata": {},
   "source": [
    "## 关系抽取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "30e9ee4e-a92d-4fe2-9986-fe726d7f1668",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(source_person=['强哥'], target_person=['阿水'], relationship=['朋友'])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"抽取句子中所有实体之间的关系\"\"\"\n",
    "    source_person: List[str] = Field(description=\"原始实体\")\n",
    "    target_person: List[str] = Field(description=\"目标实体\")\n",
    "    relationship: List[Literal[\"朋友\", \"亲人\", \"同事\"]] = Field(description=\"待选关系\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-air').call('阿水是强哥的好朋友。谢大脚是长贵的老公。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35d4df79-72e1-405a-91cc-a5d6c786b4e7",
   "metadata": {},
   "source": [
    "## 文本摘要"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "367c54cf-4c02-416e-9f7f-318e108996fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(abstract='中国探月工程20年发展，连续成功任务刷新世界纪录，总书记肯定成就。')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"抽取句子的摘要\"\"\"\n",
    "    abstract: str = Field(description=\"摘要结果\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call(\"20年来，中国探月工程从无到有、从小到大、从弱到强。党的十八大后，一个个探月工程任务连续成功，不断刷新世界月球探测史的中国纪录嫦娥三号实现我国探测器首次地外天体软着陆和巡视探测，总书记肯定“在人类攀登科技高峰征程中刷新了中国高度”；\", Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ec8ea1e-a511-45d5-90b8-20b63e315f06",
   "metadata": {},
   "source": [
    "## 关键词提取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "55e6c610-eada-4d10-be5d-cf41a62a48ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(keyword='阿水')"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"抽取句子中的关键词\"\"\"\n",
    "    keyword: str = Field(description=\"关键词\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call('阿水是强哥的好朋友。谢大脚是长贵的老公。', Text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b51af1d1-f9d3-4917-8ab5-1f90381a6a50",
   "metadata": {},
   "source": [
    "## 复杂案例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "8f34ceb1-5207-4b4a-8a3a-c847e1c39f61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(search=True, keywords=['汽车发动故障', '轮胎故障', '处理方法'], intent='查询产品问题')"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"文本问答内容解析\"\"\"\n",
    "    search: bool = Field(description=\"是否需要搜索\")\n",
    "    keywords: List[str] = Field(description=\"待选关键词\")\n",
    "    intent: Literal[\"查询客服问题\", \"查询产品问题\", \"查询系统问题\", \"其他\"] = Field(description=\"意图\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call('汽车发动和轮胎出故障了，如何处理？', Text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "1e6c723e-56e4-4341-a220-427ec78a6d2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(time=['2月8日上午'], particate=['谷爱凌'], competition=['北京冬奥会自由式滑雪女子大跳台决赛', '2022语言与智能技术竞赛'])"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Text(BaseModel):\n",
    "    \"\"\"文本问答内容解析\"\"\"\n",
    "    time: List[str] = Field(description=\"时间\")\n",
    "    particate: List[str] = Field(description=\"选手\")\n",
    "    competition: List[str] = Field(description=\"赛事名称\")\n",
    "\n",
    "GLMAgent(model_name = 'glm-4-plus').call('2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌！2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办。', Text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "094e4cfb-d519-4c2d-b3db-ebf3c60e7cfc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
